RL: Training Inference Mismatch
基本概念
KL 散度
Reverse KL Divergence
GRPO 使用的是 Reverse KL(采样自当前策略 π_θ,参考策略为 π_ref):
$$D_{KL}(\pi_\theta | \pi_{ref}) = \mathbb{E}{x \sim \pi_\theta}\left[\log\frac{\pi_\theta(x)}{\pi{ref}(x)}\right]$$
在 LLM 场景下,按 token 级别计算:
1 | KL = Σ_t π_θ(o_t | q, o_<t) · log[π_θ(o_t | q, o_<t) / π_ref(o_t | q, o_<t)] |
三种 KL 估计方法
由于直接计算 KL 的期望成本高,实践中采用蒙特卡洛近似。常见有三种估计器 [[49]][[50]]:
| 估计器 | 公式 | 特性 | 适用场景 |
|---|---|---|---|
| k1 | -log(r),其中 r = π_ref/π_θ |
无偏但方差极大,梯度不含 π_ref | PPO 中作为 reward shaping,不适合作为独立 KL loss |
| k2 | 0.5 * (log(r))² |
有偏但方差低,梯度等价于 Reverse KL | ✅ GRPO 推荐(VeRL/TRL 默认) |
| k3 | (r - 1) - log(r) |
无偏、方差低,但梯度等价于 Forward KL | 需注意:采样分布不匹配时可能不稳定 |
关键代码逻辑(VeRL/TRL 实现):
1 | # 假设已获取 per-token logps |
🔍 为什么 k2 更推荐?
- k2 的梯度:
∇_θ [0.5*(log r)²] = (log r) · ∇_θ log π_θ,恰好匹配 Reverse KL 的实用梯度形式 [[49]]- k3 虽然无偏,但其梯度对应 Forward KL,在 π_θ 与 π_ref 差距较大时,重要性采样权重
π_ref/π_θ可能爆炸,导致训练不稳定
监控指标的计算流程(每步 RL)
1 | 1️⃣ 采样阶段: |
四、实践建议
配置选择(以 VeRL 为例)[[13]][[41]]:
1
2
3
4
5actor_rollout_ref:
actor:
use_kl_loss: true # 启用 KL 正则(GRPO 必须)
kl_loss_coef: 0.001 # β 系数,数学任务可增至 0.04
kl_loss_type: "k2" # 推荐 k2;若用 k3+ 可开启 straight-through 梯度修正监控阈值参考:
- KL < 0.01:策略变化过小,可能学习缓慢
- KL ∈ [0.01, 0.1]:健康更新区间
- KL > 0.2:策略漂移过大,需检查 β 或奖励设计
调试技巧:
- 同时记录
kl_k2和kl_k3,若二者差异显著,说明 π_θ 与 π_ref 已偏离较大 - 若 KL 持续上升且 reward 不增,考虑定期重置 reference model(DeepSeek-R1 实践)[[50]]
- 同时记录
log p
log_p 是模型对每个位置实际生成的 token 计算出的对数概率(Log Probability),属于标量值,而 hidden_state 是 Transformer 层输出的高维稠密向量。两者在计算链路、数据形态和用途上完全不同。
🔍 log_p 的完整计算链路
1 | Input Tokens |
📐 维度与形态对比
| 概念 | 形状 | 数据类型 | 物理含义 |
|---|---|---|---|
hidden_state |
[B, L, D] |
float32/16 | Transformer 输出的上下文表征向量 |
logits |
[B, L, V] |
float32/16 | 词表每个 token 的原始得分 |
log_p |
[B, L] |
float32/16 | 当前策略下,每个位置真实 token 的对数概率 |
💡
D通常为 4096/7680 等,V为词表大小(如 32k/128k),而log_p已坍缩到[B, L],只保留实际生成 token 的概率信息。
💻 代码级直观实现(PyTorch)
1 | import torch |
⚠️ 常见误区澄清
| 误区 | 正确理解 |
|---|---|
“log_p 就是 hidden_state” |
hidden_state 是向量表征;log_p 是经过 LM Head + LogSoftmax + Gather 后的标量概率 |
“推理时不需要 log_p” |
推理(生成)时模型会隐式计算它用于采样;RL 训练时需显式保存用于 loss |
“log_p 越大越好” |
仅表示模型对该 token 更自信;RL 中需与 reward/advantage 结合,盲目最大化会导致 mode collapse |
| “KL 直接用 hidden_state 算” | KL 是概率分布距离,必须基于 log_p;hidden_state 是特征空间,无法直接算分布散度 |
📌 总结
log_p= 每个位置真实 token 的对数概率,形状[B, L]- 由
hidden_state → logits → log_softmax → gather得到,不是 hidden_state 本身 - 在 GRPO 中用于:KL 正则、importance ratio、策略梯度、loss mask
- 训练时需同时保存
actor_logp和ref_logp,推理时通常不显式输出但底层会计算
训推一致性
LogP Diff vs KL 散度:本质区别与阈值含义
| 维度 | 🔹 LogP Diff (训推一致性) | 🔹 KL 散度 (RL 正则/监控) |
|---|---|---|
| 比较对象 | 同一模型,train mode vs infer mode | 两个策略,actor π_θ vs reference π_ref |
| 核心目标 | 验证数值计算一致性(精度对齐) | 控制策略更新幅度(防止分布崩溃) |
| 数学形式 | δ = |log_p^train - log_p^infer| |
KL = 𝔼[log(π_θ/π_ref)] 或其近似 |
| 是否取绝对值 | ✅ 是,关注偏差大小 | ❌ 否,保留方向信息(谁更自信) |
| 是否加权 | ❌ 所有 token 平等对待 | ✅ 高概率 token 贡献更大(p·log(p/q)) |
| 量纲/单位 | log 空间的相对误差(无单位比值) | 信息论距离(nat,无单位但数值意义不同) |
| 典型阈值 | rel_diff < 0.01 (1%) | KL < 0.01~0.1 (依任务/β系数调整) |
1 | # 在 trainer 的 logging 阶段 |
训推一致性指标
你的理解里有一个关键偏差:**RL 训练里计算 log p 时,训练端不是只输入 response,而是输入 prompt + rollout 生成的 response**。response 在这里有双重身份:
- 作为 label/action:要计算每个生成 token 的 log probability;
- 作为 teacher forcing 的上下文:第 (t) 个 response token 之后的 token,需要以前面已经生成的 response token 为条件。
所以训练和推理对齐的不是“API 传入的张量长得一样”,而是对齐同一个条件概率:
[
\log \pi_\theta(y_t \mid x, y_{<t})
]
其中:
- (x):prompt;
- (y_t):第 (t) 个生成出来的 response token;
- (y_{<t}):它之前已经生成的 response tokens。
推理时逐步得到这些概率;训练时用一次 causal forward 并行算出这些概率。数学上是等价的,只要 token、mask、position id、模型权重、logits 处理方式一致。
log p 到底对齐什么
设 prompt 为:
1 | x = [x1, x2] |
rollout 生成的 response 为:
1 | y = [y1, y2, y3] |
推理阶段并不是永远只输入 prompt。实际过程是:
1 | prefill([x1, x2]) -> log p(y1 | x1, x2) |
训练阶段会把同一条轨迹拼起来:
1 | input = [x1, x2, y1, y2, y3] |
然后 causal LM 的 logits 对齐关系是:
1 | logits at x2 -> predict y1 |
因此训练端算的是:
[
\log \pi_\theta(y_1 \mid x_1,x_2)
]
[
\log \pi_\theta(y_2 \mid x_1,x_2,y_1)
]
[
\log \pi_\theta(y_3 \mid x_1,x_2,y_1,y_2)
]
这和推理阶段逐步 decode 得到的 log p 是同一个东西。
需要注意的是,prompt token 通常不参与 RL loss,但它们必须作为上下文参与 attention。也就是说:
- attention mask 不能把 prompt 屏蔽掉;
- loss mask / response mask 只是在计算 loss 时忽略 prompt 部分。
如果训练端真的只输入 response,那么算出来的是:
[
\log \pi_\theta(y_t \mid y_{<t})
]
这当然无法和推理阶段的:
[
\log \pi_\theta(y_t \mid x, y_{<t})
]
对齐。这种流程就是错的。
为什么一次训练前向可以等价
decoder-only causal LM 的序列概率分解为:
[
\pi_\theta(y \mid x)
\prod_{t=1}^{T}
\pi_\theta(y_t \mid x, y_{<t})
]
推理阶段是一个因子一个因子地采样:
1 | p(y1 | x) |
训练阶段用 teacher forcing,把完整序列:
1 | [prompt, response] |
一次送入模型。由于 causal mask 的存在,每个位置只能看见自己左边的 token,不能看见未来 token,所以它可以并行计算所有条件概率。
以 token 序列:
1 | s = [x1, x2, y1, y2, y3] |
为例,模型输出 logits:
1 | z0, z1, z2, z3, z4 |
其中:
1 | z1 -> predict y1 |
最后一个 logits z4 是用来预测 y3 后面的下一个 token 的,通常不参与当前 response 的 log p 计算。
一个简化的 PyTorch 对齐逻辑如下:
1 | # prompt_ids: [m] |
这个 resp_logp 就应该和推理引擎在 rollout 时记录的 output token logprobs 对齐。
如果你使用 HuggingFace CausalLM 的 labels,常见做法是:
1 | labels = ids.clone() |
因为 HF CausalLM 内部通常会做 shift:用 logits[:, :-1] 预测 labels[:, 1:]。所以把原始 labels 中 prompt 部分置为 -100 后,第一个 response token 仍然会由最后一个 prompt token 的 logits 来预测。
prefill + decode 和 full forward 的关系
prefill + decode 不是另一种概率模型,它只是带 KV cache 的增量计算。
推理:
1 | prefill(prompt) 生成 prompt 的 KV cache |
训练 full forward:
1 | 一次性输入 [prompt, y1, y2, y3, ...] |
KV cache 的作用只是避免重复计算历史 token 的 key/value。理论上:
1 | full forward 的 prefix hidden states |
和:
1 | prefill + decode 逐步得到的 hidden states |
应该一致。
因此,只要以下条件一致,二者的 per-token log p 应该接近相等:
- 模型权重一致;
- tokenizer 和 token ids 一致;
- attention mask 一致;
- position ids 一致;
- RoPE scaling、YaRN、NTK 等位置编码配置一致;
- logits 处理方式一致;
- dtype 和 kernel 数值误差在可接受范围内。
实际工程中不一定 bitwise 相等,尤其是 bf16、FlashAttention、PagedAttention、tensor parallel、FP8 KV cache 等场景,但差异应该很小。若出现系统性大偏差,就说明存在训推不一致。
RL 崩溃为什么和 log p 不一致有关
以 PPO / GRPO 类算法为例,训练时通常需要 old logprob 和 new logprob:
[
r_t =
\exp(
\log \pi_{\theta_{\text{new}}}(y_t \mid x,y_{<t})
\log \pi_{\theta_{\text{old}}}(y_t \mid x,y_{<t})
)
]
其中:
old_logprob:rollout 时生成该 token 的策略概率;new_logprob:训练时当前 actor 对同一个 token 的概率;- (r_t):importance ratio。
如果刚同步完 actor,且还没有做 optimizer step,那么理论上:
1 | new_logprob ≈ old_logprob |
如果此时就出现较大偏差,比如:
1 | new_logprob - old_logprob |
大面积偏离 0,那么 PPO ratio 会被错误放大或缩小,clip、KL、advantage 加权都会失真,训练就可能崩溃。
所以这里说的“训推不一致”,通常不是指“训练输入 response,推理输入 prompt”这种概念差异,而是指:
1 | rollout/inference engine 记录的 log p |
和:
1 | training engine 对同一批 prompt+response 重算的 log p |
不一致。
常见的不一致来源
1. token 序列不一致
这是最常见的问题。
需要确保训练端使用的不是重新 detokenize 再 retokenize 的文本,而是 rollout 时真实生成的 token ids:
1 | prompt_token_ids + generated_response_token_ids |
常见坑包括:
- BOS 是否自动添加;
- chat template 是否一致;
- 是否使用
add_generation_prompt=True; - assistant header 是否被当成 prompt 还是 response;
- EOS token 是否包含在 response 中;
- stop string 和 stop token 的处理是否一致;
- 训练端是否重新拼文本后再 tokenize,导致边界 token 变化。
2. shift 和 mask 错位
response 第一个 token 的 logprob 来自最后一个 prompt token 的 logits。
如果 prompt 长度为 (m),response 长度为 (n),那么 response logprob 对应:
1 | logits[m-1 : m-1+n] |
而不是:
1 | logits[m : m+n] |
常见错误是 off-by-one。
3. logits processor 不一致
推理时可能使用:
- temperature;
- top-p;
- top-k;
- repetition penalty;
- min length;
- bad words;
- forced EOS;
- stop token suppression。
如果推理记录的是处理后的 logprob,而训练端用的是原始 logits 的 logprob,就会不一致。
例如 temperature 为 (\tau) 时,采样分布是:
[
\text{softmax}(z / \tau)
]
而不是:
[
\text{softmax}(z)
]
所以要么两边都用 raw logits,要么两边都应用相同的 temperature。调试时建议先关闭所有 processor:
1 | temperature = 1 |
4. padding、position id、packing 不一致
训练端常用 padding / sequence packing,推理端常用动态 batching / paged attention。需要确保:
- pad token 不参与 attention;
- prompt 不被 loss mask 误当成 attention mask 屏蔽;
- position ids 计算一致;
- packed sequence 之间不能互相 attention;
- RoPE scaling 配置一致;
- left padding / right padding 不导致 position id 差异。
5. 模型权重版本不一致
RL 系统通常有两个模型副本:
1 | rollout engine actor |
需要确认:
- 权重是否已经同步;
- LoRA adapter 是否一致;
- rollout engine 是否加载了 merge 后权重;
- tensor parallel 切分是否一致;
- 是否存在异步 rollout 的 stale policy;
- 量化权重、FP8 KV cache 是否引入较大误差。
6. 数值 kernel 差异
full forward 和 decode 可能使用不同 kernel:
- FlashAttention;
- PagedAttention;
- fused RMSNorm;
- fused softmax;
- bf16/fp16/fp32;
- tensor parallel all-reduce;
- vocab parallel softmax。
小误差正常,但大面积系统性误差不正常。
经验上,fp32/eager 模式可以非常接近;bf16、不同 attention kernel 下可能有 (10^{-3}) 到 (10^{-2}) 量级差异。这个不是硬标准,但如果偏差明显更大,或者 ratio 明显偏离 1,就要排查。
建议的排查流程
可以按下面顺序做最小化一致性测试:
- 固定一批 prompt。
- rollout engine 生成 response,并保存:
prompt_token_idsresponse_token_idsold_logprobsattention_maskposition_ids
- 关闭 temperature、top-p、top-k、repetition penalty 等 logits processor。
- 确认训练端和推理端使用同一份权重。
- 训练端用:
1 | input_ids = prompt_token_ids + response_token_ids |
重算 per-token logprob。
- 对齐 response 部分:
1 | train_logprob[i] 对齐 infer_logprob[i] |
其中:
1 | train_logprob[i] |
- 比较:
1 | delta = train_logprob - infer_logprob |
期望:
1 | delta ≈ 0 |
如果不满足,优先检查:
1 | token ids -> chat template -> shift -> mask -> logits processor -> position ids -> 权重同步 -> dtype/kernel |
一句话总结:
推理是逐步生成同一条 response,训练是对这条固定 response 做 teacher-forcing 评分;二者对齐的是每个 token 的条件 log probability,而不是表面上的输入 API。prefill+decode 与一次 full forward 在 causal mask 下理论等价,工程上的不一致才是需要排查的核心。